package com.chatplus.application.aiprocessor.platform.image.mj;

import cn.bugstack.openai.session.Configuration;
import com.chatplus.application.aiprocessor.handler.MJWebSocketHandler;
import com.chatplus.application.aiprocessor.platform.image.ImgAiProcessorService;
import com.chatplus.application.aiprocessor.provider.ImgAiProcessorServiceProvider;
import com.chatplus.application.common.exception.BadRequestException;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.util.FileUtils;
import com.chatplus.application.common.util.OkHttpClientUtil;
import com.chatplus.application.common.util.PlusJsonUtils;
import com.chatplus.application.domain.dto.ApiKeyDto;
import com.chatplus.application.domain.dto.MJJobDto;
import com.chatplus.application.domain.dto.extend.ImgResultDto;
import com.chatplus.application.domain.entity.draw.MjJobEntity;
import com.chatplus.application.domain.notification.MJJobNotification;
import com.chatplus.application.domain.request.MjCallbackNotifyRequest;
import com.chatplus.application.domain.request.MjImageJobRequest;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.enumeration.ImageTaskTypeEnum;
import com.chatplus.application.service.draw.MjJobService;
import com.chatplus.application.web.notification.NotificationPublisher;
import okhttp3.Call;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Midjourney (MJ) image-generation processor.
 * <p>
 * {@link #process(Object)} does not draw synchronously: it builds an {@link MJJobDto} and publishes it
 * as an {@link MJJobNotification}. Progress is later fetched from
 * {@code <api key url>/mj/task/{taskId}/fetch} in {@link #getImageProgress(Long)} and pushed to the
 * user through {@link MJWebSocketHandler}.
 **/
@Service(value = ImgAiProcessorServiceProvider.SERVICE_NAME_PRE + "MJ")
public class MjImageProcessor extends ImgAiProcessorService {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(MjImageProcessor.class);
    /** Stored progress value that marks a finished job. */
    private static final int PROGRESS_DONE = 100;
    /** Returned by {@link #getImageProgress(Long)} when the progress could not be determined. */
    private static final int PROGRESS_UNKNOWN = -1;

    private final NotificationPublisher notificationPublisher;
    private final MjJobService mjJobService;
    private final MJWebSocketHandler mjWebSocketHandler;

    public MjImageProcessor(NotificationPublisher notificationPublisher,
                            MjJobService mjJobService,
                            MJWebSocketHandler mjWebSocketHandler) {
        this.notificationPublisher = notificationPublisher;
        this.mjJobService = mjJobService;
        this.mjWebSocketHandler = mjWebSocketHandler;
    }

    /**
     * Converts an {@link MjImageJobRequest} into an {@link MJJobDto} and publishes it for asynchronous
     * submission, then notifies the user's WebSocket that the task list changed.
     *
     * @param prompt the job request; must be an {@link MjImageJobRequest}
     * @return always {@code null}; the generated image is not produced by this call
     * @throws BadRequestException if the task type is missing or unsupported, or the uploaded images do
     *                             not meet the task's requirements
     */
    public ImgResultDto process(Object prompt) {
        MjImageJobRequest mjImageJobRequest = (MjImageJobRequest) prompt;
        ImageTaskTypeEnum taskTypeEnum = mjImageJobRequest.getTaskType();
        if (taskTypeEnum == null) {
            // A switch on a null enum throws NullPointerException; report it as a bad request instead.
            throw new BadRequestException("未知的任务类型");
        }
        MJJobDto mjJobDto = new MJJobDto();
        // NOTE(review): the local job id is sent as "state", presumably so callbacks can be matched
        // back to our record — confirm against the MJ proxy contract.
        mjJobDto.setState(String.valueOf(mjImageJobRequest.getMjJobId()));
        switch (taskTypeEnum) {
            case TEXT_2_IMG:
                mjJobDto.setPrompt(getPrompt(mjImageJobRequest));
                mjJobDto.setBase64Array(handleBase64Array(mjImageJobRequest.getImgArr()));
                break;
            case BLEND: {
                List<String> base64Array = handleBase64Array(mjImageJobRequest.getImgArr());
                // Blend needs between two and five usable images; for more, text-to-image should be used.
                if (base64Array.size() < 2 || base64Array.size() > 5) {
                    throw new BadRequestException("请上传两张以上的有效图片，最多不超过五张，超过五张图片请使用文生图功能");
                }
                mjJobDto.setBase64Array(base64Array);
                break;
            }
            case SWAP_FACE: {
                List<String> base64Array = handleBase64Array(mjImageJobRequest.getImgArr());
                // Face swap needs exactly two usable images: the first is the source, the second the target.
                if (base64Array.size() != 2) {
                    throw new BadRequestException("换脸操作需要上传两张有效图片");
                }
                mjJobDto.setSourceBase64(base64Array.get(0));
                mjJobDto.setTargetBase64(base64Array.get(1));
                break;
            }
            case UPSCALE:
                mjJobDto.setCustomId(String.format("MJ::JOB::upsample::%d::%s", mjImageJobRequest.getIndex(), mjImageJobRequest.getMessageHash()));
                mjJobDto.setTaskId(mjImageJobRequest.getHandleTaskId());
                break;
            case VARIATION:
                mjJobDto.setCustomId(String.format("MJ::JOB::variation::%d::%s", mjImageJobRequest.getIndex(), mjImageJobRequest.getMessageHash()));
                mjJobDto.setTaskId(mjImageJobRequest.getHandleTaskId());
                break;
            case ACTION:
                mjJobDto.setCustomId(mjImageJobRequest.getCustomId());
                mjJobDto.setTaskId(mjImageJobRequest.getHandleTaskId());
                break;
            default:
                throw new BadRequestException("未知的任务类型");
        }
        ApiKeyDto apiKeyDto = openApiKey.getFirst();
        // Register the callback URL only when one is configured for this API key.
        if (StringUtils.isNotEmpty(apiKeyDto.getNotifyUrl())) {
            mjJobDto.setNotifyHook(apiKeyDto.getNotifyUrl());
        }
        MJJobNotification mjJobNotification = new MJJobNotification();
        mjJobNotification.setMjJobId(mjImageJobRequest.getMjJobId());
        mjJobNotification.setApiKeyDto(apiKeyDto);
        mjJobNotification.setMjJobDto(mjJobDto);
        notificationPublisher.publish(mjJobNotification);
        // Push an initial progress of 0 to the requesting user.
        mjWebSocketHandler.sendTaskUpdatedMessage(mjImageJobRequest.getUserId(), 0);
        return null;
    }

    /**
     * Converts each image URL to a base64 {@code data:} URI via {@link FileUtils#getBase64FromImageURL}.
     * Empty URLs, and URLs whose conversion yields an empty string, are skipped.
     *
     * @param imgList image URLs; may be {@code null}
     * @return the data URIs in input order; never {@code null}
     */
    private List<String> handleBase64Array(List<String> imgList) {
        if (CollectionUtils.isEmpty(imgList)) {
            return Collections.emptyList();
        }
        List<String> base64Array = new ArrayList<>(imgList.size());
        for (String url : imgList) {
            if (StringUtils.isEmpty(url)) {
                continue;
            }
            String base64 = FileUtils.getBase64FromImageURL(url);
            if (StringUtils.isNotEmpty(base64)) {
                // NOTE(review): the MIME type is always image/png regardless of the source format —
                // confirm the MJ proxy does not rely on it.
                base64Array.add(String.format("data:image/png;base64,%s", base64));
            }
        }
        return base64Array;
    }

    /**
     * Refreshes and returns the progress of an MJ job.
     * <p>
     * A job already stored at 100% is reported without contacting the API. Otherwise the task is
     * fetched from {@code <api key url>/mj/task/{taskId}/fetch}, applied via
     * {@link MjJobService#handleNotify} and persisted. In both cases the latest stored state is pushed
     * to the job owner's WebSocket.
     *
     * @param id local job id
     * @return {@code 100} for a finished job, the refreshed progress after a successful fetch, or
     *         {@code -1} if the fetch failed
     * @throws BadRequestException if no job exists with the given id
     */
    @Override
    public int getImageProgress(Long id) {
        instance();
        MjJobEntity mjJobEntity = mjJobService.getById(id);
        if (mjJobEntity == null) {
            throw new BadRequestException("任务不存在: " + id);
        }
        if (progressOf(mjJobEntity) == PROGRESS_DONE) {
            // notifyUpdateTask expects the job id (not the user id).
            notifyUpdateTask(mjJobEntity.getId());
            return PROGRESS_DONE;
        }
        int progress = PROGRESS_UNKNOWN;
        ApiKeyDto apiKeyDto = openApiKey.getFirst();
        OkHttpClient okHttpClient = OkHttpClientUtil.getOkHttpClient(apiKeyDto.getProxyUrl(), true);
        Request request = new Request.Builder()
                .addHeader("Content-Type", Configuration.APPLICATION_JSON)
                .addHeader("Accept", Configuration.APPLICATION_JSON)
                .url(apiKeyDto.getUrl() + String.format("/mj/task/%s/fetch", mjJobEntity.getTaskId()))
                .get()
                .addHeader("Authorization", "Bearer " + apiKeyDto.getApiKey())
                .build();
        Call call = okHttpClient.newCall(request);
        try (Response response = call.execute()) {
            if (response.isSuccessful() && response.body() != null) {
                String body = response.body().string();
                LOGGER.message("MJ图片获取画图进度").context("id", id).context("response", response).context("body", body).info();
                MjCallbackNotifyRequest mjCallbackNotifyRequest = PlusJsonUtils.parseObject(body, MjCallbackNotifyRequest.class);
                mjJobService.handleNotify(mjJobEntity, mjCallbackNotifyRequest);
                mjJobService.updateById(mjJobEntity);
                // Assumes handleNotify updates the entity in place (it is persisted right above).
                progress = progressOf(mjJobEntity);
            } else {
                LOGGER.message("获取MJ图片进度失败").context("id", id).context("code", response.code()).error();
            }
        } catch (Exception e) {
            LOGGER.message("获取MJ图片进度失败").context("id", id).exception(e).error();
        } finally {
            notifyUpdateTask(mjJobEntity.getId());
        }
        return progress;
    }

    /**
     * Pushes the stored progress of a job to its owner's WebSocket.
     * Does nothing (besides logging) when the job no longer exists.
     *
     * @param jobId local job id
     */
    @Override
    public void notifyUpdateTask(Long jobId) {
        MjJobEntity mjJobEntity = mjJobService.getById(jobId);
        if (mjJobEntity == null) {
            LOGGER.message("MJ任务不存在，跳过进度推送").context("jobId", jobId).error();
            return;
        }
        mjWebSocketHandler.sendTaskUpdatedMessage(mjJobEntity.getUserId(), mjJobEntity.getProgress());
    }

    @Override
    public AiPlatformEnum getChannel() {
        return AiPlatformEnum.MJ;
    }

    /**
     * Builds the prompt sent to MJ.
     * <p>
     * For SWAP_FACE / BLEND the prompt is a {@code "<type value>:<url1>,<url2>,..."} summary. Otherwise
     * the user's prompt is extended with MJ parameters taken from the request. Each parameter is only
     * appended when the user has not already written that parameter (as a whole token) in the prompt.
     *
     * @param mjImageJobRequest the job request
     * @return the final prompt; never {@code null}
     */
    private String getPrompt(MjImageJobRequest mjImageJobRequest) {
        ImageTaskTypeEnum taskType = mjImageJobRequest.getTaskType();
        if (taskType == ImageTaskTypeEnum.SWAP_FACE || taskType == ImageTaskTypeEnum.BLEND) {
            return String.format("%s:%s", taskType.getValue(),
                    String.join(",", CollectionUtils.emptyIfNull(mjImageJobRequest.getImgArr())));
        }
        String prompt = mjImageJobRequest.getPrompt() == null ? "" : mjImageJobRequest.getPrompt();
        if (StringUtils.isNotEmpty(mjImageJobRequest.getRate()) && !hasParam(prompt, "--ar", "--aspect")) {
            prompt += " --ar " + mjImageJobRequest.getRate();
        }
        if (isPositive(mjImageJobRequest.getSeed()) && !hasParam(prompt, "--seed")) {
            prompt += " --seed " + mjImageJobRequest.getSeed();
        }
        // Whole-token check: a substring check for "--s" would also match the "--seed" appended above.
        if (isPositive(mjImageJobRequest.getStylize()) && !hasParam(prompt, "--s", "--stylize")) {
            prompt += " --s " + mjImageJobRequest.getStylize();
        }
        if (isPositive(mjImageJobRequest.getChaos()) && !hasParam(prompt, "--c", "--chaos")) {
            prompt += " --c " + mjImageJobRequest.getChaos();
        }
        if (StringUtils.isNotEmpty(mjImageJobRequest.getImg()) && !hasParam(prompt, "--img")) {
            // Image prompts go in front of the text prompt.
            prompt = mjImageJobRequest.getImg() + " " + prompt;
            if (isPositive(mjImageJobRequest.getWeight()) && !hasParam(prompt, "--iw")) {
                prompt += " --iw " + mjImageJobRequest.getWeight();
            }
        }
        if (Boolean.TRUE.equals(mjImageJobRequest.getRaw()) && !hasParam(prompt, "--style")) {
            prompt += " --style raw";
        }
        if (isPositive(mjImageJobRequest.getQuality()) && !hasParam(prompt, "--q", "--quality")) {
            prompt += " --q " + mjImageJobRequest.getQuality();
        }
        if (StringUtils.isNotEmpty(mjImageJobRequest.getNegPrompt()) && !hasParam(prompt, "--no")) {
            prompt += " --no " + mjImageJobRequest.getNegPrompt();
        }
        if (Boolean.TRUE.equals(mjImageJobRequest.getTile()) && !hasParam(prompt, "--tile")) {
            prompt += " --tile";
        }
        if (StringUtils.isNotEmpty(mjImageJobRequest.getCref()) && !hasParam(prompt, "--cref")) {
            prompt += " --cref " + mjImageJobRequest.getCref();
            if (isPositive(mjImageJobRequest.getCw())) {
                prompt += " --cw " + mjImageJobRequest.getCw();
            } else {
                prompt += " --cw 100";
            }
        }
        if (StringUtils.isNotEmpty(mjImageJobRequest.getSref()) && !hasParam(prompt, "--sref")) {
            prompt += " --sref " + mjImageJobRequest.getSref();
        }
        if (StringUtils.isNotEmpty(mjImageJobRequest.getModel()) && !hasParam(prompt, "--v", "--version", "--niji")) {
            // Model values may or may not carry their own leading space; normalise to exactly one.
            prompt += " " + StringUtils.trim(mjImageJobRequest.getModel());
        }
        return prompt;
    }

    /**
     * Returns {@code true} when any whitespace-separated token of {@code prompt} equals one of
     * {@code names}. Unlike {@code String.contains}, "--s" does not match "--seed" or "--sref".
     */
    private static boolean hasParam(String prompt, String... names) {
        String[] tokens = StringUtils.split(prompt);
        if (tokens == null) {
            return false;
        }
        for (String token : tokens) {
            for (String name : names) {
                if (name.equals(token)) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns {@code true} when {@code value} is non-null and strictly positive (boxed or primitive). */
    private static boolean isPositive(Number value) {
        return value != null && value.doubleValue() > 0;
    }

    /** Returns the job's stored progress, or {@link #PROGRESS_UNKNOWN} when it is {@code null}. */
    private static int progressOf(MjJobEntity job) {
        Number progress = job.getProgress();
        return progress == null ? PROGRESS_UNKNOWN : progress.intValue();
    }

    @Override
    public long getRunningJobCount(Long userId) {
        return mjJobService.getRunningJobCount(userId);
    }
}
